{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp loss_strategy.base\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
    "from nbdev.showdoc import show_doc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LossCombinationStrategyBase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from collections import deque\n",
    "from typing import Dict, List\n",
    "\n",
    "import tensorflow as tf\n",
    "from m3tl.utils import get_phase\n",
    "from tensorflow.python.util.nest import (flatten,\n",
    "                                         flatten_with_joined_string_paths)\n",
    "\n",
    "\n",
    "class LossCombinationStrategyBase(tf.keras.Model):\n",
    "    \"\"\"Base class for strategies that combine per-problem losses.\n",
    "\n",
    "    Subclasses implement `call`. The helpers defined here flatten nested\n",
    "    loss dicts and re-nest a Keras `History` into a given structure.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, name:str, *args, **kwargs):\n",
    "        \"\"\"Store `params` and create the bounded history buffers.\n",
    "\n",
    "        Args:\n",
    "            params: params object; only `params.problem_list` is read here.\n",
    "            name: Keras name given to this strategy.\n",
    "            *args, **kwargs: forwarded to `tf.keras.Model.__init__`.\n",
    "        \"\"\"\n",
    "        # `name` is a keyword argument of tf.keras.Model; handed over\n",
    "        # positionally it is not used as the layer name.\n",
    "        super(LossCombinationStrategyBase, self).__init__(*args, name=name, **kwargs)\n",
    "        self.params = params\n",
    "        self.problem_list = self.params.problem_list\n",
    "        # rolling buffers, each keeps at most the 100 newest entries\n",
    "        self.hist_loss_dict = deque(maxlen=100)\n",
    "        self.hist_metric_dict = deque(maxlen=100)\n",
    "\n",
    "    def extract_loss_metric_dict_from_history(self, \n",
    "                                            history: tf.keras.callbacks.History,\n",
    "                                            structure: dict,\n",
    "                                            prefix='val_') -> dict:\n",
    "        \"\"\"Re-nest the flat `history.history` dict so it mirrors `structure`.\n",
    "\n",
    "        Args:\n",
    "            history: Keras History; its `.history` dict is looked up with\n",
    "                the joined string paths of `structure`.\n",
    "            structure: nested structure whose joined paths select the\n",
    "                history entries to keep.\n",
    "            prefix: 'val_' to read the 'val_'-prefixed entries (with the\n",
    "                prefix stripped), or a falsy value to read keys as they are.\n",
    "\n",
    "        Returns:\n",
    "            A structure shaped like `structure` holding the matching\n",
    "            history values.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `prefix` is truthy but not 'val_'.\n",
    "            KeyError: if a path of `structure` is missing from the history.\n",
    "        \"\"\"\n",
    "        flat_records: Dict[str, List[float]] = history.history\n",
    "\n",
    "        # metrics from validation set starts with val\n",
    "        if prefix:\n",
    "            if prefix != 'val_':\n",
    "                raise ValueError('prefix should either be \"val_\" or None')\n",
    "            # Drop only the leading prefix. str.replace would also delete a\n",
    "            # later 'val_' inside the key (e.g. a problem named 'interval_cls').\n",
    "            flat_records = {\n",
    "                k[len(prefix):]: v\n",
    "                for k, v in flat_records.items() if k.startswith(prefix)\n",
    "            }\n",
    "\n",
    "        # get structure path\n",
    "        structure_path = [p for p, _ in flatten_with_joined_string_paths(structure)]\n",
    "        # make flat history and pack\n",
    "        flat_history = [flat_records[p] for p in structure_path]\n",
    "        return tf.nest.pack_sequence_as(structure=structure, flat_sequence=flat_history)\n",
    "\n",
    "    def get_all_losses(self, current_loss_dict: dict) -> List[tf.Tensor]:\n",
    "        \"\"\"Return every leaf of the nested `current_loss_dict` as a flat list.\"\"\"\n",
    "        return flatten(current_loss_dict)\n",
    "\n",
    "    def get_problem_loss(self, current_loss_dict:dict, problem: str) -> List[tf.Tensor]:\n",
    "        \"\"\"Return the leaves whose joined path contains `problem`.\"\"\"\n",
    "        flatten_loss_with_path = flatten_with_joined_string_paths(current_loss_dict)\n",
    "        # NOTE(review): substring match, so a problem name that is part of\n",
    "        # another one (e.g. 'cls' in 'multi_cls') selects both - confirm intended.\n",
    "        return [v for p, v in flatten_loss_with_path if problem in p]\n",
    "\n",
    "    def call(self, \n",
    "            current_loss_dict: dict,\n",
    "            current_metric_dict: dict,\n",
    "            history: tf.keras.callbacks.History):\n",
    "        \"\"\"Combine the current losses; must be implemented by subclasses.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: always, in this base class.\n",
    "        \"\"\"\n",
    "        raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class SumLossCombination(LossCombinationStrategyBase):\n",
    "    \"\"\"Strategy that hands back every loss unchanged, as a flat list.\n",
    "\n",
    "    No weighting or reduction happens here; presumably the caller sums\n",
    "    the returned losses - TODO confirm.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, name: str, *args, **kwargs):\n",
    "        \"\"\"Delegate construction to `LossCombinationStrategyBase`.\"\"\"\n",
    "        super(SumLossCombination, self).__init__(params, name, *args, **kwargs)\n",
    "\n",
    "    def call(self, \n",
    "            current_loss_dict: dict,\n",
    "            current_metric_dict: dict,\n",
    "            history: tf.keras.callbacks.History):\n",
    "        \"\"\"Return all leaves of `current_loss_dict` via `get_all_losses`.\n",
    "\n",
    "        `current_metric_dict` and `history` are accepted but not read.\n",
    "        \"\"\"\n",
    "        # the phase is queried here but its value is not used afterwards\n",
    "        _phase = get_phase()\n",
    "        return self.get_all_losses(current_loss_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-12 22:06:32.702 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "2021-06-12 22:06:32.702 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "2021-06-12 22:06:32.703 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_cls, problem type: cls\n",
      "2021-06-12 22:06:32.703 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_masklm, problem type: masklm\n",
      "2021-06-12 22:06:32.703 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_regression, problem type: regression\n",
      "2021-06-12 22:06:32.704 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "2021-06-12 22:06:32.704 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem weibo_premask_mlm, problem type: premask_mlm\n",
      "2021-06-12 22:06:32.705 | INFO     | m3tl.base_params:register_multiple_problems:526 - Adding new problem fake_contrastive_learning, problem type: contrastive_learning\n",
      "2021-06-12 22:06:32.705 | WARNING  | m3tl.base_params:assign_problem:620 - base_dir and dir_name arguments will be deprecated in the future. Please use model_dir instead.\n",
      "2021-06-12 22:06:32.706 | WARNING  | m3tl.base_params:prepare_dir:364 - bert_config not exists. will load model from huggingface checkpoint.\n",
      "2021-06-12 22:06:38.983 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:248 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-12 22:06:38.999 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit/train_00000.tfrecord\n",
      "2021-06-12 22:06:39.075 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:248 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-12 22:06:39.090 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit/eval_00000.tfrecord\n",
      "2021-06-12 22:06:39.122 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_fake_multi_cls/train_00000.tfrecord\n",
      "2021-06-12 22:06:39.147 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_fake_multi_cls/eval_00000.tfrecord\n",
      "2021-06-12 22:06:39.223 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_masklm/train_00000.tfrecord\n",
      "2021-06-12 22:06:39.272 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_masklm/eval_00000.tfrecord\n",
      "2021-06-12 22:06:39.336 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_premask_mlm/train_00000.tfrecord\n",
      "2021-06-12 22:06:39.398 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/weibo_premask_mlm/eval_00000.tfrecord\n",
      "2021-06-12 22:06:39.416 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/fake_contrastive_learning/train_00000.tfrecord\n",
      "2021-06-12 22:06:39.433 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:134 - Writing /tmp/tmpnn8s6_dc/fake_contrastive_learning/eval_00000.tfrecord\n",
      "2021-06-12 22:06:40.458 | INFO     | m3tl.input_fn:train_eval_input_fn:56 - sampling weights: \n",
      "2021-06-12 22:06:40.459 | INFO     | m3tl.input_fn:train_eval_input_fn:57 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit\": 0.20408163265306123,\n",
      "    \"weibo_fake_multi_cls\": 0.20408163265306123,\n",
      "    \"weibo_masklm\": 0.1836734693877551,\n",
      "    \"weibo_premask_mlm\": 0.20408163265306123,\n",
      "    \"fake_contrastive_learning\": 0.20408163265306123\n",
      "}\n",
      "2021-06-12 22:06:41.120 | INFO     | m3tl.input_fn:train_eval_input_fn:56 - sampling weights: \n",
      "2021-06-12 22:06:41.120 | INFO     | m3tl.input_fn:train_eval_input_fn:57 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit\": 0.20408163265306123,\n",
      "    \"weibo_fake_multi_cls\": 0.20408163265306123,\n",
      "    \"weibo_masklm\": 0.1836734693877551,\n",
      "    \"weibo_premask_mlm\": 0.20408163265306123,\n",
      "    \"fake_contrastive_learning\": 0.20408163265306123\n",
      "}\n",
      "404 Client Error: Not Found for url: https://huggingface.co/voidful/albert_chinese_tiny/resolve/main/tf_model.h5\n",
      "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFAlbertModel: ['predictions.LayerNorm.bias', 'predictions.dense.weight', 'predictions.LayerNorm.weight', 'predictions.bias', 'predictions.decoder.weight', 'predictions.decoder.bias', 'predictions.dense.bias']\n",
      "- This IS expected if you are initializing TFAlbertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TFAlbertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of TFAlbertModel were initialized from the PyTorch model.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFAlbertModel for predictions without further training.\n",
      "2021-06-12 22:06:45.703 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: \n",
      " {\n",
      "    \"array\": 0,\n",
      "    \"cate\": 1,\n",
      "    \"text\": 2\n",
      "}\n",
      "2021-06-12 22:06:45.823 | WARNING  | m3tl.problem_types.masklm:__init__:41 - Share embedding is enabled but hidden_size != embedding_size\n",
      "2021-06-12 22:06:45.853 | WARNING  | m3tl.problem_types.contrastive_learning:get_contrastive_learning_model:86 - None not match any contrastive learning model, using SimCSE\n",
      "2021-06-12 22:06:45.897 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0\n",
      "2021-06-12 22:06:45.897 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 0\n",
      "2021-06-12 22:06:45.898 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "WARNING: AutoGraph could not transform <bound method BertMultiTaskBody.call of <m3tl.model_fn.BertMultiTaskBody object at 0x7fd2d617b290>> and will run it as-is.\n",
      "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
      "Cause: invalid value for \"node\": expected \"ast.AST\", got \"<class 'NoneType'>\"; to visit lists of nodes, use \"visit_block\" instead\n",
      "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
      "WARNING: AutoGraph could not transform <bound method Socket.send of <zmq.sugar.socket.Socket object at 0x7fd3e16839f0>> and will run it as-is.\n",
      "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
      "Cause: module, class, method, function, traceback, frame, or code object was expected, got cython_function_or_method\n",
      "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - ETA: 0s - mean_acc: 2.0046 - fake_contrastive_learning_acc: 0.1667 - weibo_fake_cls_acc: 0.4286 - weibo_fake_ner_acc: 0.1688 - weibo_fake_regression_neg_mse: -1.2120 - weibo_fake_vector_fit_cos_sim: -0.3722 - BertMultiTaskTop/fake_contrastive_learning/simcse/losses/0: 2.4941 - BertMultiTaskTop/weibo_fake_cls/losses/0: 1.0909 - BertMultiTaskTop/weibo_fake_multi_cls/losses/0: 0.4924 - BertMultiTaskTop/weibo_fake_ner/losses/0: 1.4930 - BertMultiTaskTop/weibo_fake_regression/losses/0: 1.2120 - BertMultiTaskTop/weibo_fake_vector_fit/losses/0: 0.3722 - BertMultiTaskTop/weibo_masklm/losses/0: 9.9305 - BertMultiTaskTop/weibo_premask_mlm/losses/0: 9.7947"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 24s 24s/step - mean_acc: 2.0046 - fake_contrastive_learning_acc: 0.1667 - weibo_fake_cls_acc: 0.4286 - weibo_fake_ner_acc: 0.1688 - weibo_fake_regression_neg_mse: -1.2120 - weibo_fake_vector_fit_cos_sim: -0.3722 - BertMultiTaskTop/fake_contrastive_learning/simcse/losses/0: 2.4941 - BertMultiTaskTop/weibo_fake_cls/losses/0: 1.0909 - BertMultiTaskTop/weibo_fake_multi_cls/losses/0: 0.4924 - BertMultiTaskTop/weibo_fake_ner/losses/0: 1.4930 - BertMultiTaskTop/weibo_fake_regression/losses/0: 1.2120 - BertMultiTaskTop/weibo_fake_vector_fit/losses/0: 0.3722 - BertMultiTaskTop/weibo_masklm/losses/0: 9.9305 - BertMultiTaskTop/weibo_premask_mlm/losses/0: 9.7947 - val_loss: 37.0011 - val_mean_acc: 0.2857 - val_fake_contrastive_learning_acc: 0.1429 - val_weibo_fake_cls_acc: 0.5714 - val_weibo_fake_ner_acc: 0.1429 - val_weibo_fake_regression_neg_mse: -1.0452 - val_weibo_fake_vector_fit_cos_sim: -0.3764 - val_BertMultiTaskTop/fake_contrastive_learning/simcse/losses/0: 13.1461 - val_BertMultiTaskTop/weibo_fake_cls/losses/0: 0.8110 - val_BertMultiTaskTop/weibo_fake_multi_cls/losses/0: 0.4187 - val_BertMultiTaskTop/weibo_fake_ner/losses/0: 1.5319 - val_BertMultiTaskTop/weibo_fake_regression/losses/0: 1.0452 - val_BertMultiTaskTop/weibo_fake_vector_fit/losses/0: 0.3764 - val_BertMultiTaskTop/weibo_masklm/losses/0: 9.9035 - val_BertMultiTaskTop/weibo_premask_mlm/losses/0: 9.7685\n",
      "Epoch 2/2\n",
      "1/1 [==============================] - 1s 977ms/step - mean_acc: 1.9844 - fake_contrastive_learning_acc: 0.2500 - weibo_fake_cls_acc: 0.6667 - weibo_fake_ner_acc: 0.2424 - weibo_fake_regression_neg_mse: -0.7024 - weibo_fake_vector_fit_cos_sim: -0.1600 - BertMultiTaskTop/fake_contrastive_learning/simcse/losses/0: 2.3416 - BertMultiTaskTop/weibo_fake_cls/losses/0: 0.6408 - BertMultiTaskTop/weibo_fake_multi_cls/losses/0: 0.4468 - BertMultiTaskTop/weibo_fake_ner/losses/0: 1.4879 - BertMultiTaskTop/weibo_fake_regression/losses/0: 0.7024 - BertMultiTaskTop/weibo_fake_vector_fit/losses/0: 0.1600 - BertMultiTaskTop/weibo_masklm/losses/0: 9.9203 - BertMultiTaskTop/weibo_premask_mlm/losses/0: 9.8009 - val_loss: 35.7190 - val_mean_acc: 0.2525 - val_fake_contrastive_learning_acc: 0.1667 - val_weibo_fake_cls_acc: 0.3750 - val_weibo_fake_ner_acc: 0.2159 - val_weibo_fake_regression_neg_mse: -1.0292 - val_weibo_fake_vector_fit_cos_sim: -0.3355 - val_BertMultiTaskTop/fake_contrastive_learning/simcse/losses/0: 11.9251 - val_BertMultiTaskTop/weibo_fake_cls/losses/0: 0.9174 - val_BertMultiTaskTop/weibo_fake_multi_cls/losses/0: 0.4063 - val_BertMultiTaskTop/weibo_fake_ner/losses/0: 1.4417 - val_BertMultiTaskTop/weibo_fake_regression/losses/0: 1.0292 - val_BertMultiTaskTop/weibo_fake_vector_fit/losses/0: 0.3355 - val_BertMultiTaskTop/weibo_masklm/losses/0: 9.9006 - val_BertMultiTaskTop/weibo_premask_mlm/losses/0: 9.7633\n"
     ]
    }
   ],
   "source": [
    "from m3tl.test_base import TestBase\n",
    "from m3tl.special_tokens import TRAIN\n",
    "from m3tl.utils import create_dict_from_nested_model\n",
    "\n",
    "tb = TestBase()\n",
    "tb.test_loss_combination_strategy(loss_combination_strategy_name='sum')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(list,\n",
       "            {'BertMultiTaskTop': defaultdict(list,\n",
       "                         {'fake_contrastive_learning': defaultdict(list,\n",
       "                                      {'simcse': defaultdict(list,\n",
       "                                                   {'losses': [[2.494117498397827,\n",
       "                                                      2.3415939807891846]]})}),\n",
       "                          'weibo_fake_cls': defaultdict(list,\n",
       "                                      {'losses': [[1.0908699035644531,\n",
       "                                         0.6408037543296814]]}),\n",
       "                          'weibo_fake_multi_cls': defaultdict(list,\n",
       "                                      {'losses': [[0.4924350678920746,\n",
       "                                         0.44675901532173157]]}),\n",
       "                          'weibo_fake_ner': defaultdict(list,\n",
       "                                      {'losses': [[1.4930245876312256,\n",
       "                                         1.4878641366958618]]}),\n",
       "                          'weibo_fake_regression': defaultdict(list,\n",
       "                                      {'losses': [[1.211979866027832,\n",
       "                                         0.7023975253105164]]}),\n",
       "                          'weibo_fake_vector_fit': defaultdict(list,\n",
       "                                      {'losses': [[0.3722432553768158,\n",
       "                                         0.1600482314825058]]}),\n",
       "                          'weibo_masklm': defaultdict(list,\n",
       "                                      {'losses': [[9.930464744567871,\n",
       "                                         9.920283317565918]]}),\n",
       "                          'weibo_premask_mlm': defaultdict(list,\n",
       "                                      {'losses': [[9.794666290283203,\n",
       "                                         9.800870895385742]]})})})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_instance = LossCombinationStrategyBase(tb.params, 'test')\n",
    "# build the nested structure once and reuse it for both lookups\n",
    "loss_structure = create_dict_from_nested_model(tb.all_model)\n",
    "# validation losses (history keys starting with 'val_')\n",
    "val_losses = test_instance.extract_loss_metric_dict_from_history(\n",
    "    history=tb.all_model.history, structure=loss_structure)\n",
    "# training losses (history keys without prefix)\n",
    "train_losses = test_instance.extract_loss_metric_dict_from_history(\n",
    "    history=tb.all_model.history, structure=loss_structure, prefix='')\n",
    "# both results are packed into the same structure, so top-level keys agree\n",
    "assert val_losses.keys() == train_losses.keys()\n",
    "train_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
